// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===--- BubbleExpandShapes.cpp --- Pass to propagate expand shapes op up -===//
//
// This pass propagates expand_shape operations up the program (and conversely)
// sinks the collapse_shape operations down the program to get the elementwise
// operations into higher dimensionality to get better fusion.
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes"

// Experimental flag: when enabled, also runs the upstream patterns that move
// `tensor.collapse_shape` ops across `tensor.expand_shape` ops (gated use of
// `tensor::populateBubbleUpExpandShapePatterns` in `runOnOperation`).
static llvm::cl::opt<bool> clPropagateCollapseAcrossExpands(
    "iree-dispatch-creation-propagate-collapse-across-expands",
    llvm::cl::desc("Enables change to propagate collapse shapes across expand "
                   "shapes. This flag is meant as a stop-gap solution before "
                   "making this default due to codegen issues."),
    llvm::cl::init(false));
namespace mlir::iree_compiler::DispatchCreation {

#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"

namespace {

// Pass that greedily applies the reshape-propagation patterns defined in this
// file until a fixed point is reached. The pass body is in `runOnOperation`.
struct BubbleUpExpandShapesPass final
    : public impl::BubbleUpExpandShapesPassBase<BubbleUpExpandShapesPass> {
  using Base::Base;
  void runOnOperation() override;
};

// Convert extract_slice(dequant) to dequant(extract_slice)
//
// Because `extract_slice` ops and dequantize-like ops get cloned into regions
// later, it's okay to bubble up through multi-use dequant ops.
struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto genericOp = source.getDefiningOp<linalg::GenericOp>();
    if (!genericOp || genericOp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected source to implement `linalg::LinalgOp` and have a "
                   "single result");
    }

    // Multi-use producers are only bubbled through when they are
    // dequantize-like (bit-extend) ops, since those get cloned into dispatch
    // regions later anyway and duplicating them is acceptable.
    if (!IREE::LinalgExt::isBitExtendOp(genericOp) && !genericOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "expected source to be dequantize-like op or have a single use");
    }

    // Strided slices cannot be expressed as a simple tiling of the producer.
    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
    }

    // Projected-permutation maps make the slice-with-producer swap below a
    // straightforward tiling of the generic op.
    if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
          return map.isProjectedPermutation();
        })) {
      return rewriter.notifyMatchFailure(
          genericOp,
          "expected generic op to have all projected permutation maps");
    }

    Value replacement;
    linalg::GenericOp swappedOp;
    {
      // Tile the producer to the slice so the generic op operates directly on
      // the sliced operands.
      FailureOr<TilingResult> tilingResult =
          tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
                                                       genericOp->getResult(0));
      // NOTE(review): a failed `FailureOr` is dereferenced unconditionally
      // below, which is UB in release builds where the assert compiles out.
      // This relies on tiling of a projected-permutation generic never
      // failing — confirm, or consider a graceful bail-out.
      assert(succeeded(tilingResult) && "failed to swap extract_slice with op");
      assert(tilingResult->tiledOps.size() == 1);
      replacement = tilingResult->tiledValues[0];
      swappedOp = cast<linalg::GenericOp>(tilingResult->tiledOps[0]);
    }

    // Check if this is a rank-reducing slice, if so we need to fold the unit
    // dimensions of the op.
    // This is necessary because `replaceExtractSliceWithTiledProducer` does not
    // take into account the `extract_slice`'s implicit rank reduction. The
    // operations generated by that function will have any unit dims that were
    // removed by the original `extract_slice`. Folding them away ensures that
    // the types match.
    if (sliceOp.getSourceType().getRank() !=
        sliceOp.getResultType().getRank()) {

      llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
      // Get the indexing map for the result.
      AffineMap resultMap =
          swappedOp.getIndexingMapMatchingResult(swappedOp->getResult(0));
      linalg::ControlDropUnitDims options;
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
      // Only drop the loop dimensions that correspond to the result dims the
      // original `extract_slice` implicitly rank-reduced away.
      options.controlFn = [&](Operation *op) -> SmallVector<unsigned> {
        SmallVector<unsigned> droppedDimsVec;
        for (auto [index, expr] : llvm::enumerate(resultMap.getResults())) {
          if (!droppedDims.test(index)) {
            continue;
          }
          auto dimExpr = cast<AffineDimExpr>(expr);
          droppedDimsVec.push_back(dimExpr.getPosition());
        }
        return droppedDimsVec;
      };
      FailureOr<linalg::DropUnitDimsResult> dropUnitDims =
          linalg::dropUnitDims(rewriter, swappedOp, options);
      // NOTE(review): same release-build UB caveat as the tiling assert above.
      assert(succeeded(dropUnitDims) &&
             "failed to drop unit dims of produced operation");
      swappedOp = cast<linalg::GenericOp>(dropUnitDims->resultOp);
      replacement = swappedOp->getResult(0);
    }
    rewriter.replaceOp(sliceOp, replacement);
    return success();
  }
};

/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users.
/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the
/// former can be folded away.
/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users.
/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the
/// former can be folded away.
struct SwapExtractSliceOfFill final
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto fillProducer = sliceOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillProducer) {
      return failure();
    }

    // Slice the fill's init operand instead of its result, then fill the
    // smaller tensor.
    Value slicedInit = tensor::ExtractSliceOp::create(
        rewriter, sliceOp.getLoc(), sliceOp.getType(),
        fillProducer.getOutputs()[0], sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        sliceOp, fillProducer.getInputs(), ValueRange{slicedInit});
    return success();
  }
};

/// Bubbles a `tensor.expand_shape` op through a `tensor.extract_slice` op. This
/// pattern only gets applied when the `extract_slice` doesn't modify dimensions
/// that are expanded by the `expand_shape` and none of the expanded dimensions
/// are dynamic.
/// TODO: move this upstream with other tensor bubbling patterns.
struct BubbleExpandThroughExtract final
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  using Base::Base;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto extractOp = expandOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp) {
      return failure();
    }

    ArrayRef<int64_t> extractSrcShape = extractOp.getSourceType().getShape();
    ArrayRef<int64_t> extractDstShape = extractOp.getResultType().getShape();
    const uint64_t extractSrcRank = extractSrcShape.size();

    // Check that none of the expanded dimensions are dynamic or are sliced by
    // the `extract_slice`.
    const llvm::SmallBitVector droppedDims = extractOp.getDroppedDims();
    const SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();
    // `reassoc` indexes the (rank-reduced) extract result, so subtract the
    // number of dropped source dims seen so far to translate a source dim `i`
    // into its reassociation-group index.
    int64_t droppedDimCount = 0;
    for (uint64_t i = 0; i < extractSrcRank; ++i) {
      if (droppedDims.test(i)) {
        ++droppedDimCount;
        continue;
      }
      // Unexpanded (singleton-group) dims may be freely sliced.
      if (reassoc[i - droppedDimCount].size() == 1) {
        continue;
      }
      // Expanded dims must be static and passed through whole by the slice.
      if (ShapedType::isDynamic(extractSrcShape[i]) ||
          extractSrcShape[i] != extractDstShape[i - droppedDimCount]) {
        return rewriter.notifyMatchFailure(
            extractOp, "Extract modifies the expanded dimension");
      }
    }

    // Construct a reassociation that expands `extract_slice`'s source by
    // combining the reassociation from the `expand_shape` with the dropped dims
    // from the `extract_slice`.
    SmallVector<ReassociationIndices> newReassociation;
    newReassociation.reserve(extractSrcRank);
    int64_t count = 0;
    uint64_t expandedIdx = 0;
    for (uint64_t i = 0; i < extractSrcRank; ++i) {
      if (droppedDims.test(i)) {
        // Dropped dims stay unexpanded (their own singleton group) so the new
        // extract_slice can rank-reduce them afterwards.
        newReassociation.push_back(ReassociationIndices{count++});
      } else {
        int64_t numExpanded = reassoc[expandedIdx++].size();
        newReassociation.push_back(
            llvm::to_vector(llvm::seq(count, count + numExpanded)));
        count += numExpanded;
      }
    }

    const SmallVector<OpFoldResult> oldOffsets = extractOp.getMixedOffsets();
    const SmallVector<OpFoldResult> oldSizes = extractOp.getMixedSizes();
    const SmallVector<OpFoldResult> oldStrides = extractOp.getMixedStrides();

    RankedTensorType expandedType = expandOp.getResultType();
    ArrayRef<int64_t> expandedShape = expandedType.getShape();
    auto zeroAttr = rewriter.getIndexAttr(0);
    auto oneAttr = rewriter.getIndexAttr(1);

    // Find the new offsets/sizes/strides for the `extract_slice `& new expanded
    // shape for the `expand_shape`.
    SmallVector<int64_t> newExpandShape;
    SmallVector<OpFoldResult> newOffsets;
    SmallVector<OpFoldResult> newSizes;
    SmallVector<OpFoldResult> newStrides;
    // `expandedShape` lacks the dropped dims, so again offset by the number of
    // dropped dims encountered so far when reading from it.
    droppedDimCount = 0;
    for (const auto &[inDim, outDims] : llvm::enumerate(newReassociation)) {
      droppedDimCount += droppedDims.test(inDim);
      if (outDims.size() == 1) {
        // Unexpanded dim: carry over the original slice parameters verbatim.
        newExpandShape.push_back(extractSrcShape[inDim]);
        newOffsets.push_back(oldOffsets[inDim]);
        newSizes.push_back(oldSizes[inDim]);
        newStrides.push_back(oldStrides[inDim]);
        continue;
      }
      // Expanded dims were verified above to be untouched by the slice: take
      // the full extent with zero offset and unit stride.
      for (auto outDim : outDims) {
        int64_t expandedDim = expandedShape[outDim - droppedDimCount];
        assert(ShapedType::isStatic(expandedDim));
        newExpandShape.push_back(expandedDim);
        newOffsets.push_back(zeroAttr);
        newSizes.push_back(rewriter.getIndexAttr(expandedDim));
        newStrides.push_back(oneAttr);
      }
    }

    auto newExpandType =
        RankedTensorType::get(newExpandShape, expandedType.getElementType());
    // The builder can't fail to infer the output_shape because none of
    // the dynamic dimensions are expanded.
    auto newExpand = tensor::ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), newExpandType, extractOp.getSource(),
        newReassociation);

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandOp, expandedType, newExpand, newOffsets, newSizes, newStrides);
    return success();
  }
};

/// Bubbles a `tensor.expand_shape` through a `tensor.concat` by expanding each
/// concat input individually and concatenating the expanded results. Only
/// applied when each input's extent along the concat dim is (provably)
/// divisible by the product of the non-outermost expanded sizes of that dim's
/// reassociation group.
struct BubbleExpandThroughConcat final
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = expandOp.getSrc().getDefiningOp<tensor::ConcatOp>();
    if (!concatOp) {
      return failure();
    }

    // Get concat dimension and input/result types.
    int64_t concatDim = concatOp.getDim();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();
    RankedTensorType expandedType = expandOp.getResultType();
    ArrayRef<int64_t> expandedShape = expandedType.getShape();

    // Find the "new" concat dim and if it is part of an expansion dim
    int64_t newConcatDim = 0;
    for (int64_t i = 0; i < concatDim; ++i) {
      newConcatDim += reassoc[i].size();
    }
    int64_t expandedDimCount = reassoc[concatDim].size();

    SmallVector<int64_t> expandShapeProducts;
    int64_t countExpandedDynamicDims = 0;
    // Loop over all but the outermost expanded dims.
    for (Value input : concatOp.getInputs()) {
      int64_t expandShapeProduct = 1;
      auto inputType = cast<RankedTensorType>(input.getType());
      int64_t inputDim = inputType.getShape()[concatDim];
      // if the input dim is dynamic, we rely on checking the divisibility of
      // the other inputs. Example:
      //   %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<100x100xf32>,
      //   tensor<?x100xf32>) -> tensor<?x100xf32> %1 = tensor.expand_shape %0
      //   [[0, 1], [2]] output_shape [?, 10, 100] : tensor<?x100xf32> into
      //   tensor<?x10x100xf32>
      // The propagation here legal because we know ? is divisible by 10
      // because (? + 100) % 10 == 0.
      if (ShapedType::isDynamic(inputDim) && expandedDimCount > 1) {
        countExpandedDynamicDims++;
        // (A+ B + 10) % 10 == 0 does not imply A % 10 == 0 and
        // B % 10 == 0, so we cannot propagate.
        if (countExpandedDynamicDims > 1) {
          return rewriter.notifyMatchFailure(
              expandOp, "More than one dynamic expanded dim");
        }
      }
      // Check all static input shapes for divisibility.
      for (int64_t j = 1; j < expandedDimCount; ++j) {
        int64_t expDim = expandedShape[newConcatDim + j];
        if (ShapedType::isDynamic(expDim)) {
          // If expanded dim is dynamic and not outermost, we cannot propagate.
          return rewriter.notifyMatchFailure(
              expandOp, "Expanded dim is dynamic and not outermost");
        }
        // Calculate the product of static dims.
        expandShapeProduct *= expDim;
      }
      if (ShapedType::isStatic(inputDim) &&
          !(inputDim % expandShapeProduct == 0)) {
        // ie concat(tensor<6xf32>, tensor<6xf32>) -> expand to
        // tensor<3x2x2xf32> cannot be done because 6 % (2 * 2) != 0.
        return rewriter.notifyMatchFailure(
            expandOp,
            "Input dim is not divisible by the product of inner expanded dims");
      }
      expandShapeProducts.push_back(expandShapeProduct);
    } // else we can always propagate the expand_shape through the concat.

    SmallVector<OpFoldResult> mixedOutputShape = expandOp.getMixedOutputShape();
    SmallVector<int64_t> staticOutputShape =
        llvm::to_vector(expandedType.getShape());
    // Create new expand_shape ops for each input. Each input reuses the full
    // result shape except along the (outermost) new concat dim, which is
    // recomputed per-input as input_extent / product-of-inner-expanded-dims.
    SmallVector<Value> newInputs;
    for (auto [input, product] :
         llvm::zip(concatOp.getInputs(), expandShapeProducts)) {
      auto type = input.getType();
      auto inputType = cast<RankedTensorType>(type);

      // Compute this input's extent along the new concat dim, folding to a
      // constant when the input dim is static.
      AffineExpr concatDimExpr;
      bindSymbols(rewriter.getContext(), concatDimExpr);
      auto divMap = concatDimExpr.floorDiv(product);
      OpFoldResult dimOfr =
          tensor::getMixedSize(rewriter, expandOp.getLoc(), input, concatDim);
      OpFoldResult concatDimValue = affine::makeComposedFoldedAffineApply(
          rewriter, expandOp.getLoc(), divMap, ArrayRef<OpFoldResult>{dimOfr});
      mixedOutputShape[newConcatDim] = concatDimValue;
      staticOutputShape[newConcatDim] =
          mlir::getConstantIntValue(concatDimValue)
              .value_or(ShapedType::kDynamic);
      auto newType =
          RankedTensorType::get(staticOutputShape, inputType.getElementType());
      Value newExpand;
      newExpand =
          tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newType,
                                        input, reassoc, mixedOutputShape);
      newInputs.push_back(newExpand);
    }
    // Create new concat op on expanded inputs.
    auto newConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                              newConcatDim, newInputs);
    rewriter.replaceOp(expandOp, newConcat.getResult());
    return success();
  }
};

/// Returns true if some reassociation group that actually expands (has more
/// than one index) maps to an expanded shape containing a unit dimension,
/// i.e. the reshape introduces unit dims rather than regrouping real extents.
static bool
isExpandingUnitDims(ArrayRef<ReassociationIndices> reassociationIndices,
                    ArrayRef<int64_t> expandedShape) {
  return llvm::any_of(
      reassociationIndices, [&expandedShape](ReassociationIndicesRef group) {
        if (group.size() <= 1) {
          return false;
        }
        return llvm::any_of(
            group, [&](int64_t dim) { return expandedShape[dim] == 1; });
      });
}

// Optimistic check to make sure reshapes are moved if they could block fusion.
static bool isReshapeBlockingFusion(Operation *producer, Operation *consumer) {
  auto isFusableOp = [](Operation *op) {
    if (!op) {
      return false;
    }
    return isa_and_nonnull<linalg::LinalgDialect,
                           IREE::LinalgExt::IREELinalgExtDialect,
                           tensor::TensorDialect>(op->getDialect());
  };
  return isFusableOp(producer) && isFusableOp(consumer);
}

} // namespace

/// If the domain of the operation is being expanded by unit dimensions, check
/// if it's possible to have an infinite loop where the unit dim expansion keeps
/// on propagating infinitely.
static bool canCauseReshapingLoopByExpansion(Operation *producer,
                                             Operation *consumer) {
  auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(consumer);
  if (!expandShapeOp) {
    return false;
  }
  if (!isExpandingUnitDims(expandShapeOp.getReassociationIndices(),
                           expandShapeOp.getResultType().getShape())) {
    return false;
  }
  // The producer always has at least one use (the expand_shape itself); the
  // problematic ping-pong only arises when there are additional uses.
  return !llvm::hasSingleElement(producer->getUses());
}

void BubbleUpExpandShapesPass::runOnOperation() {
  MLIRContext *context = &getContext();

  RewritePatternSet bubbleExpandShapePatterns(context);
  // Per-operand control function for the reshape-by-expansion fusion patterns:
  // returning false blocks propagation across this producer/consumer pair.
  linalg::ControlFusionFn bubbleUpExpansionControlFn =
      [&](OpOperand *fusedOperand) {
        Operation *producer = fusedOperand->get().getDefiningOp();
        Operation *consumer = fusedOperand->getOwner();
        // Only rewrite ops that are still outside dispatch regions.
        if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
          return false;
        }

        // Avoid infinite ping-ponging of unit-dim expand_shapes.
        if (canCauseReshapingLoopByExpansion(producer, consumer)) {
          return false;
        }

        // Don't reintroduce unit dims via propagating edge unit dim reshapes.
        if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer);
            expandOp &&
            isExpandingUnitDims(expandOp.getReassociationIndices(),
                                expandOp.getResultType().getShape()) &&
            llvm::none_of(
                consumer->getUsers(), [&](Operation *collapseConsumer) {
                  return isReshapeBlockingFusion(producer, collapseConsumer);
                })) {
          return false;
        }
        // Symmetric check for a unit-dim collapse_shape producer: only bubble
        // when the reshape actually sits between two fusable ops.
        if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer);
            collapseOp &&
            isExpandingUnitDims(collapseOp.getReassociationIndices(),
                                collapseOp.getSrcType().getShape()) &&
            !isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
                                     consumer)) {
          return false;
        }

        // If producer generic op is elementwise op, bubble up the expand shape
        // past this operation.
        if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
          // If producer generic op is elementwise op, bubble up the expand
          // shape past this operation.
          // If bubbling across reduction ops is enabled, allow all generic ops.
          return (enableBubbleUpExpandShapesAcrossReductionOps ||
                  llvm::all_of(producerGenericOp.getIteratorTypesArray(),
                               linalg::isParallelIterator));
        }

        // Do not bubble up expand shapes across named ops for now.
        if (isa<linalg::LinalgOp>(producer)) {
          return false;
        }

        // Do not push expand shapes down across operations with reduction
        // iterator types.
        // TODO: This condition should be removed.
        if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
          return isa<linalg::GenericOp>(consumerLinalgOp) &&
                 llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
                              linalg::isParallelIterator);
        }
        // Fuse in all other cases.
        return true;
      };

  // Expand/Collapse shape bubbling patterns.
  linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                    bubbleUpExpansionControlFn);
  IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
      bubbleExpandShapePatterns, bubbleUpExpansionControlFn);

  // Extract slice bubbling patterns.
  bubbleExpandShapePatterns.insert<BubbleUpExtract>(context);
  bubbleExpandShapePatterns.insert<SwapExtractSliceOfFill>(context);

  // Add patterns to do some additional cleanup (on top of canonicalizations
  // that can be done later) of reshape ops.
  tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
  bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);
  bubbleExpandShapePatterns.insert<BubbleExpandThroughConcat>(context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
                                                     context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(
      bubbleExpandShapePatterns, context);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
                                                      context);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(
      bubbleExpandShapePatterns);

  // Experimental: also bubble expand_shapes across collapse_shapes.
  if (clPropagateCollapseAcrossExpands) {
    tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns);
  }

  // Run to a fixed point; the control function above guards against loops.
  GreedyRewriteConfig rewriteConfig;
  rewriteConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit);
  if (failed(applyPatternsGreedily(getOperation(),
                                   std::move(bubbleExpandShapePatterns),
                                   rewriteConfig))) {
    getOperation()->emitOpError("Failed to perform elementwise operations");
    return signalPassFailure();
  }
}

} // namespace mlir::iree_compiler::DispatchCreation
